{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shell escape: run `nvidia-smi` to report the GPUs visible to this session.\n",
    "!nvidia-smi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import RobertaTokenizer, RobertaTokenizerFast\n",
    "from transformers import RobertaForSequenceClassification\n",
    "from transformers import RobertaConfig, RobertaModel, RobertaForMaskedLM\n",
    "from transformers import LineByLineTextDataset, DataCollatorForLanguageModeling\n",
    "from transformers import Trainer, TrainingArguments\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Downloading: 100%|██████████| 1.33G/1.33G [02:07<00:00, 11.2MB/s]\n"
     ]
    }
   ],
   "source": [
    "# Load the pretrained `roberta-large` checkpoint with a masked-language-modelling head.\n",
    "# This is the model that the later cells resize, re-initialise and continue pre-training.\n",
    "maskedlm_model = RobertaForMaskedLM.from_pretrained(\"roberta-large\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the fast tokenizer that matches the `roberta-large` checkpoint.\n",
    "# NOTE(review): `do_lower_case=True` does not appear to lowercase the input here -- the saved\n",
    "# output of the next cell still shows ['CO', 'VID'] -- confirm whether the flag is needed.\n",
    "tokenizer = RobertaTokenizerFast.from_pretrained(\"roberta-large\", do_lower_case=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['CO', 'VID']\n",
      "['cor', 'on', 'av', 'irus']\n"
     ]
    }
   ],
   "source": [
    "# Baseline: show how the stock vocabulary splits the pandemic terms into sub-word pieces.\n",
    "for word in (\"COVID\", \"coronavirus\"):\n",
    "    print(tokenizer.tokenize(word))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "50265\n",
      "50269\n"
     ]
    }
   ],
   "source": [
    "# Register each pandemic term as a single token and print the vocabulary size\n",
    "# before and after. Tokens are added one at a time, in this order, so their ids\n",
    "# are assigned in the same sequence.\n",
    "print(len(tokenizer))\n",
    "for new_token in (\"COVID\", \"Covid\", \"covid\", \"coronavirus\"):\n",
    "    tokenizer.add_tokens([new_token])\n",
    "print(len(tokenizer))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['COVID']\n",
      "['Covid']\n",
      "['covid']\n",
      "['coronavirus']\n"
     ]
    }
   ],
   "source": [
    "# Each added term should now come back as exactly one token.\n",
    "for word in (\"COVID\", \"Covid\", \"covid\", \"coronavirus\"):\n",
    "    print(tokenizer.tokenize(word))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['COVID', 'Ġ', 'Covid', 'Ġ', 'covid', 'Ġand', 'Ġ', 'coronavirus', 'Ġare', 'Ġbad']\n"
     ]
    }
   ],
   "source": [
    "# Tokenize a full sentence that mixes the four added tokens with ordinary words,\n",
    "# to see how the added tokens behave in running text.\n",
    "test_sentence = \"COVID Covid covid and coronavirus are bad\"\n",
    "print(tokenizer.tokenize(test_sentence))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed the embeddings of the newly added tokens with the average embedding of a\n",
    "# few related in-vocabulary words, instead of the random rows created by\n",
    "# `resize_token_embeddings`.\n",
    "new_tokens = [\"COVID\", \"Covid\", \"covid\", \"coronavirus\"]\n",
    "seed_words = [\"virus\", \"respiratory\", \"pandemic\"]\n",
    "\n",
    "# Grow the embedding matrix once so that every added token has a row.\n",
    "maskedlm_model.resize_token_embeddings(len(tokenizer))\n",
    "embedding_matrix = maskedlm_model.get_input_embeddings().weight\n",
    "\n",
    "# Look the new rows up by token id rather than by negative position, so the\n",
    "# result does not depend on the order in which the tokens were added.\n",
    "new_token_ids = tokenizer.convert_tokens_to_ids(new_tokens)\n",
    "assert tokenizer.unk_token_id not in new_token_ids, \"add the new tokens to the tokenizer first\"\n",
    "coronavirus_id = new_token_ids[new_tokens.index(\"coronavirus\")]\n",
    "\n",
    "# Copy: `.numpy()` shares memory with the weight, which is overwritten below.\n",
    "random_vector = embedding_matrix[coronavirus_id].detach().cpu().numpy().copy()\n",
    "\n",
    "plt.title(\"Randomly Initialized Vector\")\n",
    "plt.hist(random_vector, bins=50)\n",
    "plt.show()\n",
    "\n",
    "with torch.no_grad():\n",
    "    seed_vectors = []\n",
    "    for word in seed_words:\n",
    "        # `convert_tokens_to_ids(\"pandemic\")` silently returns the unknown-token id\n",
    "        # when the bare string is not a vocabulary entry (RoBERTa stores word-initial\n",
    "        # pieces with a leading-space marker). Encoding \" word\" yields the real piece\n",
    "        # ids; a word that splits into several pieces is represented by their mean.\n",
    "        piece_ids = tokenizer.encode(\" \" + word, add_special_tokens=False)\n",
    "        assert piece_ids and tokenizer.unk_token_id not in piece_ids, word\n",
    "        seed_vectors.append(embedding_matrix[piece_ids].mean(dim=0))\n",
    "    # Every seed word contributes equally to the final vector.\n",
    "    mean_embedding = torch.stack(seed_vectors).mean(dim=0)\n",
    "\n",
    "    # In-place write under no_grad: the parameter object (and any weight tying\n",
    "    # that refers to it) is left intact and no autograd history is recorded.\n",
    "    embedding_matrix[new_token_ids] = mean_embedding\n",
    "\n",
    "mean_vector = embedding_matrix[coronavirus_id].detach().cpu().numpy()\n",
    "\n",
    "plt.title(\"Custom Initialized Vector\")\n",
    "plt.hist(mean_vector, bins=50)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build everything the masked-LM training run needs: dataset, collator,\n",
    "# training arguments and the Trainer itself.\n",
    "\n",
    "# One training example per line of the tweet corpus; `block_size` is passed to\n",
    "# LineByLineTextDataset as the per-example length limit.\n",
    "dataset = LineByLineTextDataset(\n",
    "    tokenizer=tokenizer,\n",
    "    file_path=\"pretraining_tweets_en_full_clean.txt\",\n",
    "    block_size=32,\n",
    ")\n",
    "\n",
    "# Masked-language-modelling collator with a masking probability of 0.4.\n",
    "data_collator = DataCollatorForLanguageModeling(\n",
    "    tokenizer=tokenizer, mlm=True, mlm_probability=0.4\n",
    ")\n",
    "\n",
    "# `per_device_train_batch_size` replaces the deprecated `per_gpu_train_batch_size`;\n",
    "# the batch size per device is unchanged.\n",
    "# NOTE(review): checkpoints are written to the working directory every 10000 steps\n",
    "# and `save_total_limit` is disabled, so they accumulate -- confirm this is intended.\n",
    "training_args = TrainingArguments(\n",
    "    output_dir=\"./\",\n",
    "    overwrite_output_dir=True,\n",
    "    learning_rate=3e-05,\n",
    "    num_train_epochs=3,\n",
    "    per_device_train_batch_size=32,\n",
    "    save_steps=10000,\n",
    "    #save_total_limit=2,\n",
    ")\n",
    "\n",
    "# Set up trainer\n",
    "trainer = Trainer(\n",
    "    model=maskedlm_model,\n",
    "    args=training_args,\n",
    "    data_collator=data_collator,\n",
    "    train_dataset=dataset,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the masked-LM training configured on `trainer` in the previous cell.\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the trained model under `covid_roberta_40`.\n",
    "trainer.save_model(\"covid_roberta_40\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the tokenizer, including the added tokens, under the same `covid_roberta_40`\n",
    "# path as the model so the resized embedding matrix and the vocabulary stay in step.\n",
    "tokenizer.save_pretrained(\"covid_roberta_40\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.10 64-bit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "vscode": {
   "interpreter": {
    "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
